{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install pgmpy from source.\n",
    "# Every `!` command runs in its own subshell, so a standalone `!cd pgmpy` does not\n",
    "# change the working directory of the next `!` line; install by path instead.\n",
    "# Running pip through the kernel's own interpreter installs into this notebook's environment.\n",
    "# NOTE(review): this clones the unpinned master branch; later cells use `BayesianModel`,\n",
    "# which newer pgmpy releases have renamed - pin a compatible release if imports fail.\n",
    "import sys\n",
    "\n",
    "!git clone https://github.com/pgmpy/pgmpy\n",
    "!\"{sys.executable}\" -m pip install ./pgmpy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pgmpy.factors.discrete import TabularCPD\n",
    "from pgmpy.models import BayesianModel\n",
    "\n",
    "# Student network: D (difficulty) and I (intelligence) are the parents of G (grade);\n",
    "# G is the parent of L (letter) and I is the parent of S (SAT score).\n",
    "student_model = BayesianModel([('D', 'G'),\n",
    "                               ('I', 'G'),\n",
    "                               ('G', 'L'),\n",
    "                               ('I', 'S')])\n",
    "\n",
    "# In a TabularCPD, `values` has one row per state of the variable and one column per\n",
    "# combination of evidence states, with the first evidence variable varying slowest.\n",
    "# The columns of P(G | I, D) below are therefore (I=0, D=0), (I=0, D=1), (I=1, D=0), (I=1, D=1).\n",
    "grade_cpd = TabularCPD(\n",
    "    variable='G', # node name\n",
    "    variable_card=3, # number of states of the node\n",
    "    values=[[0.3, 0.05, 0.9, 0.5], # conditional probability table of the node\n",
    "            [0.4, 0.25, 0.08, 0.3],\n",
    "            [0.3, 0.7, 0.02, 0.2]],\n",
    "    evidence=['I', 'D'], # parent nodes of this node\n",
    "    evidence_card=[2, 2] # number of states of each parent node\n",
    ")\n",
    "\n",
    "# Root nodes have no evidence, so their table must have shape (variable_card, 1):\n",
    "# one single-entry row per state. A flat row like [[0.6, 0.4]] has shape (1, 2),\n",
    "# which pgmpy releases that validate the table shape reject.\n",
    "difficulty_cpd = TabularCPD(\n",
    "            variable='D',\n",
    "            variable_card=2,\n",
    "            values=[[0.6], [0.4]]\n",
    ")\n",
    "\n",
    "intel_cpd = TabularCPD(\n",
    "            variable='I',\n",
    "            variable_card=2,\n",
    "            values=[[0.7], [0.3]]\n",
    ")\n",
    "\n",
    "# P(L | G): columns correspond to G = 0, 1, 2\n",
    "letter_cpd = TabularCPD(\n",
    "            variable='L',\n",
    "            variable_card=2,\n",
    "            values=[[0.1, 0.4, 0.99],\n",
    "                    [0.9, 0.6, 0.01]],\n",
    "            evidence=['G'],\n",
    "            evidence_card=[3]\n",
    ")\n",
    "\n",
    "# P(S | I): columns correspond to I = 0, 1\n",
    "sat_cpd = TabularCPD(\n",
    "            variable='S',\n",
    "            variable_card=2,\n",
    "            values=[[0.95, 0.2],\n",
    "                    [0.05, 0.8]],\n",
    "            evidence=['I'],\n",
    "            evidence_card=[2]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach the conditional probability tables to the network\n",
    "student_model.add_cpds(\n",
    "    grade_cpd,\n",
    "    difficulty_cpd,\n",
    "    intel_cpd,\n",
    "    letter_cpd,\n",
    "    sat_cpd\n",
    ")\n",
    "\n",
    "# Validate before running any queries: check_model() returns True, or raises if a CPD is\n",
    "# inconsistent with the node's parents or its probabilities do not sum to 1.\n",
    "assert student_model.check_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<TabularCPD representing P(G:3 | I:2, D:2) at 0x2b642f7dd68>,\n",
       " <TabularCPD representing P(D:2) at 0x2b642f7dda0>,\n",
       " <TabularCPD representing P(I:2) at 0x2b642f509e8>,\n",
       " <TabularCPD representing P(L:2 | G:3) at 0x2b642f7de80>,\n",
       " <TabularCPD representing P(S:2 | I:2) at 0x2b642f7de10>]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# CPDs attached to the model, listed in the order they were added\n",
    "student_model.get_cpds()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'D': {'D', 'G', 'L'}}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Nodes reachable from D through an active trail; no observed nodes are passed,\n",
    "# so this is the unconditioned case. The result is keyed by the queried node.\n",
    "student_model.active_trail_nodes('D')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(G _|_ S | D, I)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Local independencies of G: G is independent of its non-descendants\n",
    "# given its parents (D and I)\n",
    "student_model.local_independencies('G')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(D _|_ I, S)\n",
       "(D _|_ L | G)\n",
       "(D _|_ S | I)\n",
       "(D _|_ I | S)\n",
       "(D _|_ S | I, L)\n",
       "(D _|_ L, S | G, I)\n",
       "(D _|_ L | G, S)\n",
       "(D _|_ S | G, I, L)\n",
       "(D _|_ L | G, I, S)\n",
       "(G _|_ S | I)\n",
       "(G _|_ S | I, L)\n",
       "(G _|_ S | D, I)\n",
       "(G _|_ S | D, I, L)\n",
       "(I _|_ D)\n",
       "(I _|_ L | G)\n",
       "(I _|_ D | S)\n",
       "(I _|_ L | D, G)\n",
       "(I _|_ L | G, S)\n",
       "(I _|_ L | D, G, S)\n",
       "(L _|_ D, I, S | G)\n",
       "(L _|_ S | I)\n",
       "(L _|_ I, S | D, G)\n",
       "(L _|_ S | D, I)\n",
       "(L _|_ D, S | G, I)\n",
       "(L _|_ D, I | G, S)\n",
       "(L _|_ S | D, I, G)\n",
       "(L _|_ I | D, G, S)\n",
       "(L _|_ D | G, I, S)\n",
       "(S _|_ D)\n",
       "(S _|_ L | G)\n",
       "(S _|_ D, G, L | I)\n",
       "(S _|_ G, D | I, L)\n",
       "(S _|_ D, L | G, I)\n",
       "(S _|_ L | G, D)\n",
       "(S _|_ G, L | D, I)\n",
       "(S _|_ D | G, I, L)\n",
       "(S _|_ G | D, I, L)\n",
       "(S _|_ L | G, I, D)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Every conditional independence assertion implied by the network's graph structure\n",
    "student_model.get_independencies()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<pgmpy.models.MarkovModel.MarkovModel at 0x2b642f7d588>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Convert the directed model into an undirected pgmpy MarkovModel\n",
    "# (the result is only displayed here, not assigned to a variable)\n",
    "student_model.to_markov_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Finding Elimination Order: : 100%|█████████████████████████████████████████████████████| 4/4 [00:00<00:00, 1002.82it/s]\n",
      "Eliminating: I: 100%|███████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 267.39it/s]\n"
     ]
    }
   ],
   "source": [
    "# Bayesian inference by variable elimination\n",
    "from pgmpy.inference import VariableElimination\n",
    "student_infer = VariableElimination(student_model)\n",
    "# No evidence is given, so this is the marginal distribution P(G)\n",
    "prob_G = student_infer.query(variables=['G'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------+----------+\n",
      "| G    |   phi(G) |\n",
      "+======+==========+\n",
      "| G(0) |   0.3620 |\n",
      "+------+----------+\n",
      "| G(1) |   0.2884 |\n",
      "+------+----------+\n",
      "| G(2) |   0.3496 |\n",
      "+------+----------+\n"
     ]
    }
   ],
   "source": [
    "# Show the factor returned by the most recent query stored in prob_G\n",
    "print(prob_G)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Finding Elimination Order: : 100%|██████████████████████████████████████████████████████| 2/2 [00:00<00:00, 668.68it/s]\n",
      "Eliminating: L: 100%|███████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 286.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------+----------+\n",
      "| G    |   phi(G) |\n",
      "+======+==========+\n",
      "| G(0) |   0.9000 |\n",
      "+------+----------+\n",
      "| G(1) |   0.0800 |\n",
      "+------+----------+\n",
      "| G(2) |   0.0200 |\n",
      "+------+----------+\n"
     ]
    }
   ],
   "source": [
    "# Posterior P(G | I=1, D=0): the same query conditioned on observed evidence\n",
    "# (overwrites the unconditioned result previously stored in prob_G)\n",
    "prob_G = student_infer.query(\n",
    "            variables=['G'],\n",
    "            evidence={'I': 1, 'D': 0})\n",
    "print(prob_G)"
   ]
  },
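  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Learning the Bayesian network's parameters from data\n",
    "\n",
    "The cells below create a synthetic dataset and fit the CPDs of a network with the student-model structure to it by maximum likelihood estimation."
   ]
  },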
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate synthetic training data\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "SEED = 42         # fixed seed so the dataset (and the fitted CPDs) are reproducible\n",
    "N_SAMPLES = 1000  # number of rows to generate\n",
    "# Number of states per variable, matching the student-model CPDs:\n",
    "# G (grade) has 3 states, every other variable is binary.\n",
    "CARDINALITY = {'D': 2, 'I': 2, 'G': 3, 'L': 2, 'S': 2}\n",
    "\n",
    "rng = np.random.RandomState(SEED)\n",
    "# Each column is drawn independently and uniformly, so the data carry no real\n",
    "# dependency structure; the fitted CPDs should come out close to uniform.\n",
    "data = pd.DataFrame(\n",
    "    {var: rng.randint(low=0, high=card, size=N_SAMPLES) for var, card in CARDINALITY.items()},\n",
    "    columns=list(CARDINALITY),\n",
    ")\n",
    "data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPD of D:\n",
      "+------+-------+\n",
      "| D(0) | 0.484 |\n",
      "+------+-------+\n",
      "| D(1) | 0.516 |\n",
      "+------+-------+\n",
      "CPD of G:\n",
      "+------+--------------------+--------------------+-------------------+--------------------+\n",
      "| D    | D(0)               | D(0)               | D(1)              | D(1)               |\n",
      "+------+--------------------+--------------------+-------------------+--------------------+\n",
      "| I    | I(0)               | I(1)               | I(0)              | I(1)               |\n",
      "+------+--------------------+--------------------+-------------------+--------------------+\n",
      "| G(0) | 0.4666666666666667 | 0.5040983606557377 | 0.508130081300813 | 0.4962962962962963 |\n",
      "+------+--------------------+--------------------+-------------------+--------------------+\n",
      "| G(1) | 0.5333333333333333 | 0.4959016393442623 | 0.491869918699187 | 0.5037037037037037 |\n",
      "+------+--------------------+--------------------+-------------------+--------------------+\n",
      "CPD of I:\n",
      "+------+-------+\n",
      "| I(0) | 0.486 |\n",
      "+------+-------+\n",
      "| I(1) | 0.514 |\n",
      "+------+-------+\n",
      "CPD of L:\n",
      "+------+--------------------+--------------------+\n",
      "| G    | G(0)               | G(1)               |\n",
      "+------+--------------------+--------------------+\n",
      "| L(0) | 0.48582995951417   | 0.5355731225296443 |\n",
      "+------+--------------------+--------------------+\n",
      "| L(1) | 0.5141700404858299 | 0.4644268774703557 |\n",
      "+------+--------------------+--------------------+\n",
      "CPD of S:\n",
      "+------+---------------------+--------------------+\n",
      "| I    | I(0)                | I(1)               |\n",
      "+------+---------------------+--------------------+\n",
      "| S(0) | 0.5041152263374485  | 0.4961089494163424 |\n",
      "+------+---------------------+--------------------+\n",
      "| S(1) | 0.49588477366255146 | 0.5038910505836576 |\n",
      "+------+---------------------+--------------------+\n"
     ]
    }
   ],
   "source": [
    "# Define a model with the same structure as student_model and learn its CPDs from data\n",
    "from pgmpy.models import BayesianModel\n",
    "from pgmpy.estimators import MaximumLikelihoodEstimator, BayesianEstimator\n",
    "\n",
    "model = BayesianModel([('D', 'G'), ('I', 'G'), ('I', 'S'), ('G', 'L')])\n",
    "\n",
    "# Maximum likelihood estimation: each CPD entry is a relative frequency in `data`\n",
    "model.fit(data, estimator=MaximumLikelihoodEstimator)\n",
    "\n",
    "# Show every learned conditional probability distribution under its own heading\n",
    "for cpd in model.get_cpds():\n",
    "    print(f\"CPD of {cpd.variable}:\", cpd, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
